"""
    HetBlock(backward_fun, exogenous, policy, backward, backward_init=nothing,
             hetinputs=nothing, hetoutputs=nothing)

Heterogeneous-agent block built around a single backward-iteration function.

# Arguments
- `backward_fun`: backward-iteration function, wrapped in an `ExtendedFunction`; its
  `name` becomes the block's name.
- `exogenous`: name(s) of the exogenous processes.
- `policy`: name(s) of the endogenous policies (at most two are supported).
- `backward`: name(s) of the backward variables; for each `k`, `backward_fun` must take
  `k * "_p"` as an argument and return `k` as an output.
- `backward_init`: optional function providing initial guesses for the backward
  variables (wrapped in an `ExtendedFunction`).
- `hetinputs`, `hetoutputs`: optional functions (wrapped in `CombinedExtendedFunction`)
  whose inputs/outputs extend the block's inputs, outputs and internals.

# Throws
- `ErrorException` when the policy/backward specification is inconsistent with
  `backward_fun` (see the static checks in the constructor).
"""
mutable struct HetBlock <: Block
    # block-level mapping and option dictionaries
    M
    steady_state_options
    solve_steady_state_options
    impulse_nonlinear_options
    solve_impulse_nonlinear_options
    impulse_linear_options
    jacobian_options
    partial_jacobian_options

    # core specification
    backward_fun
    name
    exogenous
    policy
    backward
    non_backward_outputs

    # current interface (includes hetinputs/hetoutputs effects)
    outputs
    M_outputs
    inputs
    internals

    # interface before any hetinputs/hetoutputs were attached; used as the starting
    # point by `process_hetinputs_hetoutputs`
    original_inputs
    original_outputs
    original_internals
    original_M_outputs

    hetinputs
    hetoutputs

    backward_init

    function HetBlock(backward_fun, exogenous, policy, backward, backward_init=nothing,
                      hetinputs=nothing, hetoutputs=nothing)
        _M = Bijection(Dict())
        _steady_state_options = Dict()
        _solve_steady_state_options = Dict{String, Any}(
            "solver" => "",
            "solver_kwargs" => Dict(),
            "ttol" => 1e-12,
            "ctol" => 1e-9,
            "verbose" => false,
            "constrained_method" => "linear_continuation",
            "constrained_kwargs" => Dict())
        _impulse_nonlinear_options = Dict()
        _solve_impulse_nonlinear_options = Dict(
            "tol" => 1e-8,
            "maxit" => 30,
            "verbose" => true)
        _impulse_linear_options = Dict()
        _jacobian_options = Dict()
        _partial_jacobians_options = Dict()

        _backward_fun = ExtendedFunction(backward_fun)
        _name = _backward_fun.name
        _exogenous = OrderedSet(make_tuple(exogenous))
        _policy = OrderedSet(make_tuple(policy))
        _backward = OrderedSet(make_tuple(backward))
        _non_backward_outputs = setdiff(_backward_fun.outputs, _backward)

        # aggregate outputs are the uppercased names of the non-backward outputs
        _outputs = OrderedSet(uppercase(o) for o ∈ _non_backward_outputs)
        _M_outputs = Bijection(Dict(o => uppercase(o) for o ∈ _non_backward_outputs))

        # "<k>_p" arguments are supplied internally during backward iteration, so they are
        # not block inputs. Filter the backward function's inputs directly (rather than
        # through a `Set`) so the declared input order is preserved.
        backward_p = Set(k * "_p" for k ∈ _backward)
        _inputs = OrderedSet([k for k ∈ _backward_fun.inputs if k ∉ backward_p])
        _inputs = _inputs ∪ _exogenous
        _internals = OrderedSet(["D", "Dbeg"]) ∪ _exogenous ∪ _backward_fun.outputs

        _original_inputs = _inputs
        _original_outputs = _outputs
        _original_internals = _internals
        _original_M_outputs = _M_outputs

        # static checks
        if length(_policy) > 2
            error("More than two endogenous policies in '$(_name)', not yet supported")
        end

        for pol ∈ _policy
            if !(pol ∈ _backward_fun.outputs)
                error("Policy '$(pol)' not included as output in '$(_name)'")
            end
            if isuppercase(pol[1])
                error("Policy '$(pol)' is uppercase in '$(_name)', which is not allowed")
            end
        end

        for back ∈ _backward
            if !(back * "_p" ∈ _backward_fun.inputs)
                error("Backward variable '$(back)_p' not included as argument in '$(_name)'")
            end
            if !(back ∈ _backward_fun.outputs)
                error("Backward variable '$(back)' not included as output in '$(_name)'")
            end
            if back ∈ ["d", "dbeg", "D", "Dbeg"]
                error("A backward variable is called D or Dbeg, which are reserved for the distribution")
            end
        end

        if !isnothing(hetinputs)
            hetinputs = CombinedExtendedFunction(hetinputs)
        end
        if !isnothing(hetoutputs)
            hetoutputs = CombinedExtendedFunction(hetoutputs)
        end

        # Account for hetinputs/hetoutputs; mirrors `process_hetinputs_hetoutputs`.
        if !isnothing(hetoutputs)
            # hetoutputs may read backward-function outputs and D; anything else they
            # need becomes a block input
            _inputs = _inputs ∪ setdiff(hetoutputs.inputs, _backward_fun.outputs ∪ OrderedSet(["D"]))
            _outputs = _outputs ∪ OrderedSet(uppercase(o) for o ∈ hetoutputs.outputs)
            _M_outputs = Bijection(Dict(o => uppercase(o) for o ∈ hetoutputs.outputs)) * _original_M_outputs
            # the disaggregated hetoutput values are internal (previously `_inputs` was
            # added here by mistake, moving block inputs into the internals)
            _internals = _internals ∪ hetoutputs.outputs
        end
        if !isnothing(hetinputs)
            _inputs = _inputs ∪ hetinputs.inputs
            # values produced by hetinputs are computed, not supplied from outside
            _inputs = setdiff(_inputs, hetinputs.outputs)
            _internals = _internals ∪ hetinputs.outputs
        end

        if !isnothing(backward_init)
            backward_init = ExtendedFunction(backward_init)
        end
        _backward_init = backward_init

        new(_M,
            _steady_state_options,
            _solve_steady_state_options,
            _impulse_nonlinear_options,
            _solve_impulse_nonlinear_options,
            _impulse_linear_options,
            _jacobian_options,
            _partial_jacobians_options,
            _backward_fun,
            _name,
            _exogenous,
            _policy,
            _backward,
            _non_backward_outputs,
            _outputs,
            _M_outputs,
            _inputs,
            _internals,
            _original_inputs,
            _original_outputs,
            _original_internals,
            _original_M_outputs,
            hetinputs,
            hetoutputs,
            _backward_init
            )
    end
end

"""
    het(backward_fun; exogenous, policy, backward, backward_init=nothing,
        hetinputs=nothing, hetoutputs=nothing)

Keyword-argument convenience constructor: forwards everything positionally to `HetBlock`.
"""
het(backward_fun; exogenous, policy, backward, backward_init=nothing, hetinputs=nothing, hetoutputs=nothing) =
    HetBlock(backward_fun, exogenous, policy, backward, backward_init, hetinputs, hetoutputs)

#=
"implement later"
function static_checks()
    if length(b.policy) > 2
        error("More than two endogenous policies in '$(b.name)', not yet supported")
    end

    for pol in b.policy
        if !(pol in b.backward_fun_outputs)
            error("Policy '$(pol)' not included as output in '$(b.name)'")
        end
        if isuppercase(pol[1])
            error("Policy '$(pol)' is uppercase in '$(b.name)', which is not allowed")
        end
    end

    for back in b.backward
        if !(back * "_p" in b.backward_fun_inputs)
            error("Backward variable '$(back)'_p not included as argument in '$(b.name)'")
        end
        if !(back in b.backward_fun_outputs)
            error("Backward variable '$(back)' not included as output in '$(b.name)'")
        end
        if back in ["d", "dbeg", "D", "Dbeg"]
            error("A backward variable is called D or Dbeg, which are reserved for the distribution")
        end
    end

    for out in b.non_backward_outputs
        if isuppercase(out[1])
            error("Output '$(out)' is uppercase in '$(b.name)', which is not allowed")
        end
    end
end =#

"""
    show(io::IO, b::HetBlock)

Print a one-line summary such as
`<HetBlock 'name' with hetinput 'hi' and with hetoutput 'ho'>`, mentioning whichever of
the hetinputs/hetoutputs are attached (previously a block with only hetoutputs printed
as if it had none).
"""
function Base.show(io::IO, b::HetBlock)
    print(io, "<HetBlock '", b.name, "'")
    has_hetinputs = !isnothing(b.hetinputs)
    if has_hetinputs
        print(io, " with hetinput '", b.hetinputs.name, "'")
    end
    if !isnothing(b.hetoutputs)
        print(io, has_hetinputs ? " and with" : " with", " hetoutput '", b.hetoutputs.name, "'")
    end
    print(io, ">")
    return nothing
end

"""
    _steady_state(b::HetBlock, calibration; backward_tol=1e-8, backward_maxit=5000,
                  forward_tol=1e-10, forward_maxit=100_000)

Solve the block's steady state.

The steps are:
1. Iterate the backward function to convergence.
2. Iterate the distribution forward.
3. Apply hetoutputs.
4. Aggregate each non-backward output (and hetoutput) against `D`, under its uppercased name.

Returns a `SteadyStateDict` whose top level holds everything not in `b.internals`; the
internal entries are stored under `b.name`.
"""
function _steady_state(b::HetBlock, calibration; backward_tol=1e-8, backward_maxit=5000, forward_tol=1e-10, forward_maxit=100_000)
    # working dictionary: calibration + hetinput values + initial backward guesses
    ss = initialize_backward(b, update_with_hetinputs(b, extract_ss_dict(b, calibration)))

    # policies first, then the stationary distribution they induce
    ss = backward_steady_state(b, ss; tol=backward_tol, maxit=backward_maxit)
    Dbeg, D = forward_steady_state(b, ss; tol=forward_tol, maxit=forward_maxit)
    ss["Dbeg"] = Dbeg
    ss["D"] = D

    ss = update_with_hetoutputs(b, ss)

    to_aggregate = isnothing(b.hetoutputs) ? b.non_backward_outputs : b.non_backward_outputs ∪ b.hetoutputs.outputs
    merge!(ss, Dict(uppercase(o) => dot(D, ss[o]) for o ∈ to_aggregate))

    # NOTE(review): confirm this split matches the SteadyStateDict constructor's
    # expectations (toplevel positional, `internals` keyword).
    toplevel = Dict(k => v for (k, v) ∈ ss if k ∉ b.internals)
    own_internals = Dict(k => v for (k, v) ∈ ss if k ∈ b.internals)
    return SteadyStateDict(toplevel; internals=Dict(b.name => own_internals))
end

"""
    impulse_linear(b::HetBlock, ss, inputs, outputs, Js; h=1e-4, twosided=false, T=nothing)

Linear impulse response of `outputs` to the input paths in `inputs`.

The block Jacobian is built around `ss` with `_jacobian`, then applied to `inputs`.
`Js` is accepted for interface compatibility and is not used.

`T` is the Jacobian horizon. When `nothing`, it is taken as the length of the first
input path. Previously `ss` was never passed to the Jacobian and `transpose(inputs)` was
passed where the horizon belongs.
"""
function impulse_linear(b::HetBlock, ss, inputs, outputs, Js; h=1e-4, twosided=false, T=nothing)
    input_names = collect(keys(inputs))
    # NOTE(review): assumes `inputs[name]` is a path whose length is the horizon — confirm
    # against the ImpulseDict implementation.
    horizon = isnothing(T) ? length(inputs[first(input_names)]) : T
    J = _jacobian(b, ss, input_names, outputs; T=horizon, h=h, twosided=twosided)
    return ImpulseDict(apply(J, inputs))
end

"""
    _jacobian(b::HetBlock, ss, inputs, outputs; T, h=1e-4, twosided=false)

Jacobian of the block's aggregate `outputs` with respect to `inputs` over horizon `T`,
computed with the fake news algorithm around the steady state `ss`. `h` and `twosided`
control the numerical differentiation.

Returns a `JacobianDict` keyed by uppercased output name, then input name.
"""
function _jacobian(b::HetBlock, ss, inputs, outputs; T, h=1e-4, twosided=false)
    ss = extract_ss_dict(b, ss)
    # translate requested (external) output names back to the block's own names
    outputs = inv(b.M_outputs) * outputs

    # Step 0: laws of motion and differentiable versions of the block's functions
    exog = make_exog_law_of_motion(b, ss)
    endog = make_endog_law_of_motion(b, ss)
    diff_backward, diff_hetinputs, diff_hetoutputs = jac_backward_prelim(b, ss, h, exog, twosided)
    law_of_motion = forward_shockable(CombinedTransition([exog, endog]), ss["Dbeg"])
    exog_by_output = Dict(k => expectation_shockable(exog, ss[k]) for k ∈ (outputs ∪ b.backward))

    # Step 1: curlyY and curlyD for every shocked input
    curlyYs = Dict()
    curlyDs = Dict()
    for i ∈ inputs
        curlyYs[i], curlyDs[i] = backward_fakenews(b, i, outputs, T, diff_backward, diff_hetinputs, diff_hetoutputs, law_of_motion, exog_by_output)
    end

    # Step 2: expectation vectors for every output
    curlyPs = Dict()
    for o ∈ outputs
        curlyPs[o] = expectation_vectors(b, ss[o], T - 1, law_of_motion)
    end

    # Steps 3-4: fake news matrix, then Jacobian, for each (output, input) pair
    J = Dict()
    for o ∈ outputs, i ∈ inputs
        fake_news = build_F(curlyYs[i][o], curlyDs[i], curlyPs[o])
        get!(Dict, J, uppercase(o))[i] = J_from_F(fake_news)
    end

    return JacobianDict(J; name=b.name, T=T)
end

"""
    backward_steady_state(b::HetBlock, ss; tol=1e-8, maxit=5000)

Backward iteration to the steady-state policies.

Each iteration replaces every backward variable `k` with `k * "_p"`, its expectation
under the exogenous law of motion, and then merges the outputs of `b.backward_fun`.
Convergence of the policies (within `tol`) is checked on iterations 2, 12, 22, ….

Returns a copy of `ss` holding the converged outputs; the `"_p"` entries are removed.

# Throws
- `ErrorException` if the policies have not converged after `maxit` iterations. This
  includes `maxit < 1`, which previously returned silently without iterating.
"""
function backward_steady_state(b::HetBlock, ss; tol=1e-8, maxit=5000)
    ss = copy(ss)
    exog = make_exog_law_of_motion(b, ss)

    old = Dict()
    converged = false
    for it ∈ 1:maxit
        for k ∈ b.backward
            ss[k * "_p"] = expectation(exog, ss[k])
            delete!(ss, k)
        end
        ss = merge(ss, b.backward_fun(ss))

        # first check on iteration 2, once `old` holds the previous policies
        if mod(it, 10) == 2 && all(within_tolerance(ss[k], old[k], tol) for k ∈ b.policy)
            converged = true
            break
        end
        old = merge(old, Dict(k => ss[k] for k ∈ b.policy))
    end
    converged || error("No convergence of policy functions after '$(maxit)' backward iterations!")

    for k ∈ b.backward
        delete!(ss, k * "_p")
    end

    return ss
end

"""
    forward_steady_state(b::HetBlock, ss; tol=1e-10, maxit=100_000)

Forward iteration to the steady-state distribution.

The initial `Dbeg` is chosen as follows:
- If `ss["Dbeg"]` is present, it is used as-is.
- Otherwise, it is the outer product of:
  - the stationary distribution of each exogenous process, seeded by `ss[k * "_seed"]`
    when present;
  - a uniform distribution over each policy's grid `ss[k * "_grid"]`.

Convergence of `Dbeg` within `tol` is checked on iterations 1, 11, 21, …. An
already-converged seed therefore exits after one iteration; previously the first check
was only on iteration 10.

Returns `(Dbeg, D)`: `Dbeg` before and `D` after the exogenous transition.

# Throws
- `ErrorException` if no convergence after `maxit` iterations (including `maxit < 1`).
"""
function forward_steady_state(b::HetBlock, ss; tol=1e-10, maxit=100_000)
    exog = make_exog_law_of_motion(b, ss)
    endog = make_endog_law_of_motion(b, ss)

    Dbeg_seed = get(ss, "Dbeg", nothing)
    pi_seeds = [get(ss, k * "_seed", nothing) for k ∈ b.exogenous]

    if isnothing(Dbeg_seed)
        # stationary distribution of each exogenous process
        pis = [stationary(exog[i], pi_seed) for (i, pi_seed) ∈ enumerate(pi_seeds)]
        # uniform distribution over each endogenous grid
        endog_uniform = [fill(1/length(ss[k * "_grid"]), length(ss[k * "_grid"])) for k ∈ b.policy]
        # outer product as initial guess
        Dbeg = outer(vcat(pis, endog_uniform))
    else
        Dbeg = Dbeg_seed
    end

    D = forward(exog, Dbeg)
    converged = false
    for it in 1:maxit
        Dbeg_new = forward(endog, D)
        D_new = forward(exog, Dbeg_new)

        # check convergence only every 10 iterations for efficiency
        if mod(it, 10) == 1 && within_tolerance(Dbeg, Dbeg_new, tol)
            converged = true
            break
        end
        Dbeg = Dbeg_new
        D = D_new
    end
    converged || error("No convergence after '$(maxit)' forward iterations!")

    # D is after the exogenous shock, Dbeg before it
    return Dbeg, D
end

"""
    forward_nonlinear(b::HetBlock, ss, individual_paths, exog_path, monotonic)

Placeholder: not implemented yet; does nothing and returns `nothing`.
"""
forward_nonlinear(b::HetBlock, ss, individual_paths, exog_path, monotonic) = nothing

"""
    backward_fakenews(b::HetBlock, input_shocked, output_list, T, differentiable_backward_fun,
                      differentiable_hetinput, differentiable_hetoutput, law_of_motion, exog)

Part 1 of the fake news algorithm: responses of the aggregate outputs (`curlyY`) and of
the distribution (`curlyD`) to a unit fake-news shock to `input_shocked`, for horizons
1 to `T`.

Returns `(curlyYs, curlyDs)`:
- `curlyYs[k]` is a length-`T` vector for each output `k`.
- `curlyDs[t, ...]` is the distribution response at horizon `t`.
"""
function backward_fakenews(b::HetBlock, input_shocked, output_list, T, differentiable_backward_fun, differentiable_hetinput, differentiable_hetoutput, law_of_motion, exog)
    # contemporaneous unit shock, plus its effect through hetinputs that depend on it
    initial_shock = Dict(input_shocked => 1)
    if !isnothing(differentiable_hetinput) && input_shocked ∈ differentiable_hetinput.inputs
        initial_shock = merge(initial_shock, diff(differentiable_hetinput, Dict(input_shocked => 1)))
    end

    curlyV, curlyD, curlyY = backward_step_fakenews(b, initial_shock, output_list, differentiable_backward_fun, differentiable_hetoutput, law_of_motion, exog; maybe_exog_shock=true)

    # shapes come from the first response; time runs along the leading axis
    curlyDs = Array{Float64}(undef, (T, size(curlyD)...))
    rest = ntuple(_ -> Colon(), ndims(curlyDs) - 1)
    curlyDs[1, rest...] = curlyD

    curlyYs = Dict(k => Array{Float64}(undef, T) for k ∈ keys(curlyY))
    for (k, y) ∈ curlyY
        curlyYs[k][1] = y
    end

    # anticipation effects: feed last step's curlyV back in as "<k>_p" shocks
    for t ∈ 2:T
        next_shock = Dict(k * "_p" => v for (k, v) ∈ curlyV)
        curlyV, curlyD_t, curlyY = backward_step_fakenews(b, next_shock, output_list, differentiable_backward_fun, differentiable_hetoutput, law_of_motion, exog)
        curlyDs[t, rest...] = curlyD_t
        for (k, y) ∈ curlyY
            curlyYs[k][t] = y
        end
    end
    return curlyYs, curlyDs
end

"""
    expectation_vectors(block::HetBlock, o_ss, T::Int, law_of_motion::Transition)

Part 2 of the fake news algorithm: expectation vectors `curlyE` for the steady-state
output `o_ss`.

Returns an array of size `(T, size(o_ss)...)`:
- Slice 1 is the demeaned beginning-of-period expectation `expectation(law_of_motion[1], o_ss)`.
- Each later slice applies `expectation(law_of_motion, ·)` to the previous slice and demeans.

Demeaning has no effect in theory but keeps the vectors converging to zero numerically.
"""
function expectation_vectors(block::HetBlock, o_ss, T::Int, law_of_motion::Transition)
    curlyEs = Array{Float64}(undef, (T, size(o_ss)...))
    rest = ntuple(_ -> Colon(), ndims(curlyEs) - 1)

    curlyEs[1, rest...] = demean(expectation(law_of_motion[1], o_ss))
    for t in 2:T
        previous = curlyEs[t - 1, rest...]
        curlyEs[t, rest...] = demean(expectation(law_of_motion, previous))
    end
    return curlyEs
end

"""
    build_F(curlyYs, curlyDs, curlyEs)

Part 3 of the fake news algorithm: build the fake news matrix from `curlyY`, `curlyD`
and `curlyE`.

With `T = size(curlyDs, 1)` and `n = size(curlyEs, 1)`, the result is an
`(n + 1) × T` `Float64` matrix:
- The first row is `curlyYs`.
- The remaining rows are the inner products of every `curlyE` slice with every `curlyD`
  slice, taken over all non-time dimensions.
"""
function build_F(curlyYs, curlyDs, curlyEs)
    nT = size(curlyDs, 1)
    nE = size(curlyEs, 1)
    # column-major reshape keeps the leading (time) axis as rows and flattens the rest
    # identically for both arrays, so the product pairs matching grid points
    E = reshape(curlyEs, nE, :)
    Dmat = reshape(curlyDs, nT, :)

    F = Array{Float64}(undef, nE + 1, nT)
    F[1, :] = curlyYs
    F[2:end, :] = E * transpose(Dmat)
    return F
end

"""
    J_from_F(F)

Part 4 of the fake news algorithm: recursively build the Jacobian from the fake news
matrix `F`.

The recursion is `J[s, t] = F[s, t] + J[s-1, t-1]` for `s, t ≥ 2`; the first row and
first column are copied from `F`. `F` itself is not modified.
"""
function J_from_F(F)
    J = copy(F)
    @views for col in 2:size(J, 2)
        J[2:end, col] .+= J[1:end-1, col-1]
    end
    return J
end

"""
    backward_step_fakenews(block::HetBlock, din_dict, output_list, differentiable_backward_fun,
                           differentiable_hetoutput, law_of_motion, exog; maybe_exog_shock=false)

Support for part 1 of the fake news algorithm: one backward step in response to the
shock `din_dict`.

Returns `(curlyV, curlyD, curlyY)`:
- `curlyV`: beginning-of-period expectations of the backward variables.
- `curlyD`: the shock to tomorrow's distribution.
- `curlyY`: the aggregate output responses today.

When `maybe_exog_shock` is true, shocks to the exogenous processes (looked up in
`din_dict` by exogenous name) feed into curlyD. Through `exog[k]`, they also affect
*both* `curlyV` and `curlyY`. Previously the `curlyY` term was missing even though the
comment promised it, and `Dbeg` was unused.
"""
function backward_step_fakenews(block::HetBlock, din_dict, output_list, differentiable_backward_fun, differentiable_hetoutput, law_of_motion, exog; maybe_exog_shock = false)
    Dbeg, D = law_of_motion[1].Dss, law_of_motion[2].Dss

    # shock perturbs outputs
    shocked_outputs = diff(differentiable_backward_fun, din_dict)
    curlyV = Dict(k => expectation(law_of_motion[1], shocked_outputs[k]) for k ∈ block.backward)

    # if there may be a shock to exogenous processes, figure out what it is
    # (`nothing` entries for exogenous processes not hit by this shock)
    shocks_to_exog = maybe_exog_shock ? [get(din_dict, k, nothing) for k ∈ block.exogenous] : nothing

    # perturbation to exog and policies affects the distribution tomorrow
    policy_shock = [shocked_outputs[k] for k ∈ block.policy]
    if length(policy_shock) == 1
        policy_shock = policy_shock[1]
    end
    curlyD = forward_shock(law_of_motion, [shocks_to_exog, policy_shock])

    # and also affects aggregate outcomes today
    if !isnothing(differentiable_hetoutput) && !isempty(output_list ∩ differentiable_hetoutput.outputs)
        merge!(shocked_outputs, diff(differentiable_hetoutput, Dict(shocked_outputs..., din_dict...), outputs=(differentiable_hetoutput.outputs ∩ output_list)))
    end
    curlyY = Dict(k => dot(D, shocked_outputs[k]) for k ∈ output_list)

    # add effects from the perturbation to exog on beginning-of-period expectations
    if maybe_exog_shock
        for k ∈ keys(curlyV)
            shock = expectation_shock(exog[k], shocks_to_exog)
            # `nothing` when the shocked input does not actually hit any exogenous process
            if !isnothing(shock)
                curlyV[k] += shock
            end
        end
        # Today's aggregates respond too. Assuming `expectation(exog, y)` is Π·y and
        # D = Πᵀ·Dbeg, Y = <D, y> = <Dbeg, Π·y>, so the shock adds <Dbeg, dΠ·y>.
        for k ∈ keys(curlyY)
            shock = expectation_shock(exog[k], shocks_to_exog)
            if !isnothing(shock)
                curlyY[k] += dot(Dbeg, shock)
            end
        end
    end

    return curlyV, curlyD, curlyY
end

"""
    jac_backward_prelim(block::HetBlock, ss, h, exog, twosided)

Support for part 1 of the fake news algorithm: preload differentiable versions of the
block's functions around `ss` (step `h`).

Hetinputs are always differentiated two-sided; hetoutputs and the backward function use
`twosided`. The backward function is linearised with each `"<k>_p"` argument set to the
expectation of backward variable `k` under `exog`.

Returns `(differentiable_backward_fun, differentiable_hetinputs, differentiable_hetoutputs)`;
the last two are `nothing` when the block has no hetinputs/hetoutputs.
"""
function jac_backward_prelim(block::HetBlock, ss, h, exog, twosided)
    diff_hetinputs = isnothing(block.hetinputs) ? nothing :
        differentiable(block.hetinputs, ss; h=h, twosided=true)
    diff_hetoutputs = isnothing(block.hetoutputs) ? nothing :
        differentiable(block.hetoutputs, ss; h=h, twosided=twosided)

    ss_p = copy(ss)
    for k ∈ block.backward
        ss_p[k * "_p"] = expectation(exog, ss_p[k])
    end
    diff_backward = differentiable(block.backward_fun, ss_p; h=h, twosided=twosided)

    return diff_backward, diff_hetinputs, diff_hetoutputs
end

"""
    process_hetinputs_hetoutputs(block::HetBlock, hetinputs, hetoutputs; tocopy=true)

Recompute `inputs`, `outputs`, `internals` and `M_outputs` from the block's `original_*`
values, so that the block reflects exactly the given `hetinputs` and `hetoutputs`
(either may be `nothing`).

With `tocopy=true` (default), a `deepcopy` of `block` is modified and returned;
otherwise `block` itself is mutated and returned.

When `hetoutputs` is `nothing`, `M_outputs` is reset to `original_M_outputs`, consistent
with `outputs`. Previously a stale hetoutput renaming could survive.
"""
function process_hetinputs_hetoutputs(block::HetBlock, hetinputs, hetoutputs; tocopy=true)
    if tocopy
        block = deepcopy(block)
    end

    inputs = copy(block.original_inputs)
    outputs = copy(block.original_outputs)
    internals = copy(block.original_internals)

    if isnothing(hetoutputs)
        block.M_outputs = block.original_M_outputs
    else
        # hetoutputs may read backward-function outputs and D; anything else is an input
        inputs = inputs ∪ setdiff(hetoutputs.inputs, block.backward_fun.outputs ∪ OrderedSet(["D"]))
        outputs = outputs ∪ OrderedSet(uppercase(o) for o ∈ hetoutputs.outputs)
        block.M_outputs = Bijection(Dict(o => uppercase(o) for o ∈ hetoutputs.outputs)) * block.original_M_outputs
        internals = internals ∪ hetoutputs.outputs
    end
    if !isnothing(hetinputs)
        inputs = inputs ∪ hetinputs.inputs
        # values produced by hetinputs are computed, not supplied from outside
        inputs = setdiff(inputs, hetinputs.outputs)
        internals = internals ∪ hetinputs.outputs
    end
    block.inputs = inputs
    block.outputs = outputs
    block.internals = internals

    block.hetinputs = hetinputs
    block.hetoutputs = hetoutputs

    # NOTE(review): the original Python code notes consequences to fix if `M` remaps names.
    return block
end

"""
    add_hetinputs(block::HetBlock, functions)

Return a copy of `block` with `functions` appended to its hetinputs (starting a new
combination if it has none).
"""
function add_hetinputs(block::HetBlock, functions)
    combined = isnothing(block.hetinputs) ? CombinedExtendedFunction(functions) : add(block.hetinputs, functions)
    return process_hetinputs_hetoutputs(block, combined, block.hetoutputs)
end

"""
    remove_hetinputs(block::HetBlock, names)

Return a copy of `block` whose hetinputs no longer include the functions named `names`.
"""
remove_hetinputs(block::HetBlock, names) =
    process_hetinputs_hetoutputs(block, remove(block.hetinputs, names), block.hetoutputs)

"""
    add_hetoutputs(block::HetBlock, functions)

Return a copy of `block` with `functions` appended to its hetoutputs (starting a new
combination if it has none).
"""
function add_hetoutputs(block::HetBlock, functions)
    combined = isnothing(block.hetoutputs) ? CombinedExtendedFunction(functions) : add(block.hetoutputs, functions)
    return process_hetinputs_hetoutputs(block, block.hetinputs, combined)
end

"""
    remove_hetoutputs(block::HetBlock, names)

Return a copy of `block` whose hetoutputs no longer include the functions named `names`.
"""
function remove_hetoutputs(block::HetBlock, names)
    return process_hetinputs_hetoutputs(block, block.hetinputs, remove(block.hetoutputs, names))
end

# Backward-compatible alias: this function was first published under this misleading
# name, although it only removes hetoutputs.
remove_hetinputs_hetoutputs(block::HetBlock, names) = remove_hetoutputs(block, names)

"""
    update_with_hetinputs(block::HetBlock, d)

Return `merge(d, block.hetinputs(d))`, or `d` itself when the block has no hetinputs.
"""
function update_with_hetinputs(block::HetBlock, d)
    isnothing(block.hetinputs) && return d
    return merge(d, block.hetinputs(d))
end

"""
    update_with_hetoutputs(block::HetBlock, d)

Return `merge(d, block.hetoutputs(d))`, or `d` itself when the block has no hetoutputs.
"""
function update_with_hetoutputs(block::HetBlock, d)
    isnothing(block.hetoutputs) && return d
    return merge(d, block.hetoutputs(d))
end

"""
    extract_ss_dict(block::HetBlock, ss)

Return a fresh, flat dictionary of steady-state values for `block`.

- For a `SteadyStateDict`, the top-level entries are copied and this block's internals
  (`ss.internals[block.name]`, if present) are merged on top.
- Any other `ss` is simply copied.
"""
function extract_ss_dict(block::HetBlock, ss)
    ss isa SteadyStateDict || return copy(ss)

    flat = copy(ss.toplevel)
    if haskey(ss.internals, block.name)
        flat = merge(flat, ss.internals[block.name])
    end
    return flat
end

"""
    initialize_backward(block::HetBlock, ss)

Ensure every backward variable has an initial value. Returns `ss` unchanged if all are
present; otherwise returns `merge(ss, block.backward_init(ss))`.

# Throws
- `ErrorException` naming the missing variables when some are absent and the block has
  no `backward_init`. Previously this surfaced as an opaque "objects of type Nothing are
  not callable" error.
"""
function initialize_backward(block::HetBlock, ss)
    all(haskey(ss, k) for k ∈ block.backward) && return ss

    if isnothing(block.backward_init)
        missing_vars = [k for k ∈ block.backward if !haskey(ss, k)]
        error("Backward variables $(missing_vars) of '$(block.name)' have no initial values and no backward_init was provided")
    end
    return merge(ss, block.backward_init(ss))
end

"""
    make_exog_law_of_motion(block::HetBlock, d)

Combine one `Markov` transition per exogenous process into a `CombinedTransition`.

The `i`-th name in `block.exogenous` reads `d[name]` (presumably its transition matrix)
and is attached to axis `i`.
"""
function make_exog_law_of_motion(block::HetBlock, d)
    markovs = [Markov(d[name], axis) for (axis, name) in enumerate(block.exogenous)]
    return CombinedTransition(markovs)
end

"""
    make_endog_law_of_motion(block::HetBlock, d; monotonic=false)

Build the endogenous law of motion from the policy functions in `d` and their grids
`d[p * "_grid"]`:
- one policy: `lottery_1d`;
- two policies: `lottery_2d`.

`block.policy` is collected before positional access, so the code does not rely on
indexing an `OrderedSet`, which is outside the `AbstractSet` interface.
"""
function make_endog_law_of_motion(block::HetBlock, d; monotonic=false)
    pols = collect(block.policy)
    if length(pols) == 1
        p = pols[1]
        return lottery_1d(d[p], d[p * "_grid"]; monotonic=monotonic)
    else
        p1, p2 = pols[1], pols[2]
        return lottery_2d(d[p1], d[p2], d[p1 * "_grid"], d[p2 * "_grid"]; monotonic=monotonic)
    end
end

"for later"
function _impulse_nonlinear()
end

"for later"
function backward_nonlinear()
end

"for later"
function forward_nonlinear()
end
